Fix dynamic dag_id resolution in TriggerDagRunOperator links #56973
hwang-cadent wants to merge 2 commits into apache:main
Congratulations on your first Pull Request and welcome to the Apache Airflow community! If you have any issues or are unsure about anything, please check our Contributors' Guide (https://github.com/apache/airflow/blob/main/contributing-docs/README.rst)
providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
TriggerDagRunOperator links
76a3ed7 to 331face
bfbe526 to 735f5af
d1598cd to 2c20224
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 5 days if no further activity occurs. Thank you for your contributions.
0750c6b to 7f284d3
897bf65 to aba1fdf
pierrejeambrun left a comment
Something probably went wrong in the rebase process. The change set is huge.
Do you mind fixing the branch, please, so we can move forward?
74a653b to f1c30c3
Rebased on the latest apache/main and cleaned up. The change set now includes only the 4 PR-related files. Ready for review. Thank you for your patience.
dabla left a comment
Looking good to me. I reviewed it from mobile, though, so I might have missed something.
fdeb33c to 26ccfed
        elif field_name == "resources":
            return Resources.from_dict(value) if value is not None else None
        elif field_name.endswith("_date"):
            # Check if value is ARG_NOT_SET before trying to deserialize as datetime
Is this change tested in test_serialized_objects?
No, this change is not directly tested in test_serialized_objects.py. It's currently tested indirectly through TriggerDagRunOperator tests, which verify the behavior when logical_date is NOTSET.
A direct unit test for OperatorSerialization._deserialize_field_value() handling ARG_NOT_SET for date fields would be clearer and more focused. Should I add a direct test in test_serialized_objects.py?
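A direct test of the kind discussed here could look roughly like the sketch below. It deliberately does not import Airflow: `ArgNotSet`, `NOTSET`, `TYPE_KEY`, `ARG_NOT_SET_TAG`, and `deserialize_date_field` are hypothetical stand-ins that only mirror the shape of the patched `_date` branch, and the timestamp encoding of datetimes is an assumption made for illustration.

```python
from datetime import datetime, timezone


class ArgNotSet:
    """Hypothetical stand-in for a 'no argument given' sentinel type."""


NOTSET = ArgNotSet()
TYPE_KEY = "__type"              # assumed name of the encoding key
ARG_NOT_SET_TAG = "arg_not_set"  # assumed type tag for the sentinel


def deserialize_date_field(value):
    """Mirror the patched branch: sentinel first, then None, then datetime."""
    if isinstance(value, dict) and value.get(TYPE_KEY) == ARG_NOT_SET_TAG:
        return NOTSET
    if value is None:
        return None
    # Assumption: datetimes are encoded as POSIX timestamps.
    return datetime.fromtimestamp(value, tz=timezone.utc)


def test_date_field_returns_sentinel_for_arg_not_set():
    assert deserialize_date_field({TYPE_KEY: ARG_NOT_SET_TAG}) is NOTSET


def test_date_field_keeps_none():
    assert deserialize_date_field(None) is None


def test_date_field_decodes_timestamp():
    expected = datetime(1970, 1, 1, tzinfo=timezone.utc)
    assert deserialize_date_field(0) == expected
```

Checking the sentinel before the datetime path is the point of the test: without that ordering, an encoded sentinel dict would be handed to the datetime decoder.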
dabla left a comment
Looking good to me, apart from one question: shouldn't this change also be tested directly in test_serialized_objects instead of relying on indirect tests?
Good point. It is currently tested indirectly through the TriggerDagRunOperator tests. A direct unit test in test_serialized_objects.py for _deserialize_field_value() handling ARG_NOT_SET for date fields would be clearer. Should I add this test, or would you prefer to handle it in a follow-up?
0534865 to 47803ea
- Add XCOM_DAG_ID constant to store resolved dag_id in XCom
- Update TriggerDagRunLink.get_link() to check XCom first for dynamic dag_ids
- Store resolved dag_id in XCom during execution for both Airflow 2.x and 3.x
- Add comprehensive tests for dynamic dag_id link generation
- Maintain backward compatibility with existing static dag_id usage
- Fix deserialization of logical_date when it's NOTSET

Fixes apache#46402

diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index db79e79..a9f1c3c770 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1595,6 +1595,11 @@ class OperatorSerialization(DAGNode, BaseSerialization):
         elif field_name == "resources":
             return Resources.from_dict(value) if value is not None else None
         elif field_name.endswith("_date"):
+            # Check if value is ARG_NOT_SET before trying to deserialize as datetime
+            if isinstance(value, dict) and value.get(Encoding.TYPE) == DAT.ARG_NOT_SET:
+                from airflow.serialization.definitions.notset import NOTSET
+
+                return NOTSET
             return cls._deserialize_datetime(value) if value is not None else None
         else:
             # For all other fields, return as-is (strings, ints, bools, etc.)
diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
index ae3f978..728a1cf 100644
--- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
+++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
@@ -53,6 +53,7 @@ except ImportError:
 
 XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
 XCOM_RUN_ID = "trigger_run_id"
+XCOM_DAG_ID = "trigger_dag_id"
 
 
 if TYPE_CHECKING:
@@ -85,21 +86,26 @@ class TriggerDagRunLink(BaseOperatorLink):
         if TYPE_CHECKING:
             assert isinstance(operator, TriggerDagRunOperator)
 
-        trigger_dag_id = operator.trigger_dag_id
-        if not AIRFLOW_V_3_0_PLUS:
-            from airflow.models.renderedtifields import RenderedTaskInstanceFields
-            from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey
-
-            core_ti_key = CoreTaskInstanceKey(
-                dag_id=ti_key.dag_id,
-                task_id=ti_key.task_id,
-                run_id=ti_key.run_id,
-                try_number=ti_key.try_number,
-                map_index=ti_key.map_index,
-            )
+        # Try to get the resolved dag_id from XCom first (for dynamic dag_ids)
+        trigger_dag_id = XCom.get_value(ti_key=ti_key, key=XCOM_DAG_ID)
+
+        # Fallback to operator attribute and rendered fields if not in XCom
+        if not trigger_dag_id:
+            trigger_dag_id = operator.trigger_dag_id
+            if not AIRFLOW_V_3_0_PLUS:
+                from airflow.models.renderedtifields import RenderedTaskInstanceFields
+                from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey
+
+                core_ti_key = CoreTaskInstanceKey(
+                    dag_id=ti_key.dag_id,
+                    task_id=ti_key.task_id,
+                    run_id=ti_key.run_id,
+                    try_number=ti_key.try_number,
+                    map_index=ti_key.map_index,
+                )
 
-            if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key):
-                trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id)  # type: ignore[no-redef]
+                if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key):
+                    trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id)  # type: ignore[no-redef]
 
         # Fetch the correct dag_run_id for the triggerED dag which is
         # stored in xcom during execution of the triggerING task.
@@ -203,7 +209,7 @@ class TriggerDagRunOperator(BaseOperator):
         self.openlineage_inject_parent_info = openlineage_inject_parent_info
         self.deferrable = deferrable
         self.logical_date = logical_date
-        if logical_date is NOTSET:
+        if isinstance(logical_date, ArgNotSet) or logical_date is NOTSET:
             self.logical_date = NOTSET
         elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)):
             self.logical_date = logical_date
@@ -216,7 +222,7 @@ class TriggerDagRunOperator(BaseOperator):
             raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x")
 
     def execute(self, context: Context):
-        if self.logical_date is NOTSET:
+        if isinstance(self.logical_date, ArgNotSet) or self.logical_date is NOTSET:
             # If no logical_date is provided we will set utcnow()
             parsed_logical_date = timezone.utcnow()
         elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
@@ -274,6 +280,14 @@ class TriggerDagRunOperator(BaseOperator):
     def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
         from airflow.providers.common.compat.sdk import DagRunTriggerException
 
+        # Store the resolved dag_id to XCom for use in the link generation
+        # This is important for dynamic dag_ids (from XCom or complex templates)
+        # In Airflow 3.x, context has both "task_instance" and "ti" keys
+        if "task_instance" in context:
+            context["task_instance"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)
+        elif "ti" in context:
+            context["ti"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)
+
         raise DagRunTriggerException(
             trigger_dag_id=self.trigger_dag_id,
             dag_run_id=run_id,
@@ -319,10 +333,11 @@ class TriggerDagRunOperator(BaseOperator):
                 raise e
         if dag_run is None:
             raise RuntimeError("The dag_run should be set here!")
-        # Store the run id from the dag run (either created or found above) to
+        # Store the run id and dag_id from the dag run (either created or found above) to
         # be used when creating the extra link on the webserver.
         ti = context["task_instance"]
         ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
+        ti.xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id)
 
         if self.wait_for_completion:
             # Kick off the deferral process
diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
index 0f8d171..920f38b 100644
--- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
+++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
@@ -140,8 +140,10 @@ class TestDagRunOperator:
         assert task.trigger_run_id == expected_run_id  # run_id is saved as attribute
 
     @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
-    @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one")
-    def test_extra_operator_link(self, mock_xcom_get_one, dag_maker):
+    @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
+    def test_extra_operator_link(self, mock_xcom_get_value, dag_maker):
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID
+
         with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
             task = TriggerDagRunOperator(
                 task_id="test_task",
@@ -153,7 +155,13 @@ class TestDagRunOperator:
         dr = dag_maker.create_dagrun(run_id="test_run_id")
         ti = dr.get_task_instance(task_id=task.task_id)
 
-        mock_xcom_get_one.return_value = ti.run_id
+        # Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID
+        def mock_get_value(ti_key, key):
+            if key == XCOM_RUN_ID:
+                return "test_run_id"
+            return None
+
+        mock_xcom_get_value.side_effect = mock_get_value
 
         link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)
 
@@ -161,6 +169,72 @@ class TestDagRunOperator:
         expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id"
         assert link == expected_url, f"Expected {expected_url}, but got {link}"
 
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
+    @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
+    def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker):
+        """Test that operator link works correctly when dag_id is dynamically resolved from XCom."""
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID
+
+        with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
+            task = TriggerDagRunOperator(
+                task_id="test_task",
+                # In real scenario, this would be a template like "{{ ti.xcom_pull(...) }}"
+                trigger_dag_id=TRIGGERED_DAG_ID,
+                trigger_run_id="test_run_id",
+            )
+
+        dr = dag_maker.create_dagrun(run_id="test_run_id")
+        ti = dr.get_task_instance(task_id=task.task_id)
+
+        # Mock XCom.get_value to return our test values
+        def mock_get_value(ti_key, key):
+            if key == XCOM_DAG_ID:
+                return "dynamic_dag_id"
+            if key == XCOM_RUN_ID:
+                return "dynamic_run_id"
+            return None
+
+        mock_xcom_get_value.side_effect = mock_get_value
+
+        link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key)
+
+        base_url = conf.get("api", "base_url", fallback="/").lower()
+        # Should use the dag_id from XCom, not the operator attribute
+        expected_url = f"{base_url}dags/dynamic_dag_id/runs/dynamic_run_id"
+        assert link == expected_url, f"Expected {expected_url}, but got {link}"
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
+    def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker):
+        """Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution."""
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID
+
+        with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
+            task = TriggerDagRunOperator(
+                task_id="test_task",
+                trigger_dag_id=TRIGGERED_DAG_ID,
+            )
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance(task_id=task.task_id)
+
+        # Create a mock task instance that stores XCom values
+        xcom_values = {}
+
+        def mock_xcom_push(key, value, **kwargs):
+            xcom_values[key] = value
+
+        ti.xcom_push = mock_xcom_push
+
+        # Execute the task (will raise exception in AF3, but should push XCom first)
+        try:
+            task.execute(context={"task_instance": ti})
+        except DagRunTriggerException:
+            pass  # Expected in Airflow 3
+
+        # Verify that the dag_id was pushed to XCom
+        assert XCOM_DAG_ID in xcom_values
+        assert xcom_values[XCOM_DAG_ID] == TRIGGERED_DAG_ID
+
     @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
     def test_trigger_dagrun_custom_run_id(self):
         task = TriggerDagRunOperator(
@@ -577,8 +651,37 @@ class TestDagRunOperatorAF2:
 
         assert task.trigger_run_id == "test_run_id"
 
-    def test_extra_operator_link(self, dag_maker, session):
+    def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker, session):
+        """Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution."""
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID
+
+        with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
+            task = TriggerDagRunOperator(
+                task_id="test_task",
+                trigger_dag_id=TRIGGERED_DAG_ID,
+                trigger_run_id="test_run_id",
+            )
+        dag_maker.create_dagrun()
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        triggering_ti = session.scalar(
+            select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id)
+        )
+        assert triggering_ti is not None
+
+        # Verify that the dag_id was pushed to XCom
+        dag_id_xcom = triggering_ti.xcom_pull(key=XCOM_DAG_ID)
+        assert dag_id_xcom == TRIGGERED_DAG_ID
+
+        # Also verify run_id is still pushed
+        run_id_xcom = triggering_ti.xcom_pull(key=XCOM_RUN_ID)
+        assert run_id_xcom == "test_run_id"
+
+    @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
+    def test_extra_operator_link(self, mock_xcom_get_value, dag_maker, session):
         """Asserts whether the correct extra links url will be created."""
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID
+
         with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
             task = TriggerDagRunOperator(
                 task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="test_run_id"
@@ -587,13 +690,18 @@ class TestDagRunOperatorAF2:
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
         triggering_ti = session.scalar(
-            select(TaskInstance).where(
-                TaskInstance.task_id == task.task_id, TaskInstance.dag_id == task.dag_id
-            )
+            select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id)
         )
 
+        # Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID
+        def mock_get_value(ti_key, key):
+            if key == XCOM_RUN_ID:
+                return "test_run_id"
+            return None
+
+        mock_xcom_get_value.side_effect = mock_get_value
+
         with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url:
-            # This is equivalent of a task run calling this and pushing to xcom
             task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key)
             assert mock_build_url.called
         args, _ = mock_build_url.call_args
@@ -603,6 +711,47 @@ class TestDagRunOperatorAF2:
         }
         assert expected_args in args
 
+    @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value")
+    def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker, session):
+        """Test that operator link works correctly when dag_id is dynamically resolved from XCom."""
+        from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID
+
+        with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
+            task = TriggerDagRunOperator(
+                task_id="test_task",
+                # In real scenario, this would be a template like "{{ ti.xcom_pull(...) }}"
+                trigger_dag_id=TRIGGERED_DAG_ID,
+                trigger_run_id="test_run_id",
+            )
+        dag_maker.create_dagrun()
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        triggering_ti = session.scalar(
+            select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id)
+        )
+        assert triggering_ti is not None
+
+        # Mock XCom.get_value to return our test values
+        def mock_get_value(ti_key, key):
+            if key == XCOM_DAG_ID:
+                return "dynamic_dag_id"
+            if key == XCOM_RUN_ID:
+                return "test_run_id"
+            return None
+
+        mock_xcom_get_value.side_effect = mock_get_value
+
+        with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url:
+            task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key)
+            assert mock_build_url.called
+        args, _ = mock_build_url.call_args
+        # Should use the dag_id from XCom, not the operator attribute
+        expected_args = {
+            "dag_id": "dynamic_dag_id",
+            "dag_run_id": "test_run_id",
+        }
+        assert expected_args in args
+
     def test_trigger_dagrun_with_logical_date(self, dag_maker):
         """Test TriggerDagRunOperator with custom logical_date."""
         custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index fc832e3..5574e30 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -4044,7 +4044,17 @@ class TestTriggerDagRunOperator:
 
         expected_calls = [
             mock.call.send(
-                msg=TriggerDagRun(
+                SetXCom(
+                    key="trigger_dag_id",
+                    value="test_dag",
+                    dag_id="test_handle_trigger_dag_run",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                ),
+            ),
+            mock.call.send(
+                TriggerDagRun(
                     dag_id="test_dag",
                     run_id="test_run_id",
                     reset_dag_run=False,
@@ -4052,7 +4062,7 @@ class TestTriggerDagRunOperator:
                 ),
             ),
             mock.call.send(
-                msg=SetXCom(
+                SetXCom(
                     key="trigger_run_id",
                     value="test_run_id",
                     dag_id="test_handle_trigger_dag_run",
@@ -4166,38 +4176,47 @@ class TestTriggerDagRunOperator:
         assert state == expected_task_state
         assert msg.state == expected_task_state
 
-        expected_calls = [
-            mock.call.send(
-                msg=TriggerDagRun(
-                    dag_id="test_dag",
-                    run_id="test_run_id",
-                    logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
-                ),
-            ),
-            mock.call.send(
-                msg=SetXCom(
-                    key="trigger_run_id",
-                    value="test_run_id",
-                    dag_id="test_handle_trigger_dag_run_wait_for_completion",
-                    task_id="test_task",
-                    run_id="test_run",
-                    map_index=-1,
-                ),
+        # Verify the expected calls were made (order may vary due to SetRenderedFields)
+        # Check each expected call individually since SetRenderedFields appears first
+        mock_supervisor_comms.send.assert_any_call(
+            SetXCom(
+                key="trigger_dag_id",
+                value="test_dag",
+                dag_id="test_handle_trigger_dag_run_wait_for_completion",
+                task_id="test_task",
+                run_id="test_run",
+                map_index=-1,
             ),
-            mock.call.send(
-                msg=GetDagRunState(
-                    dag_id="test_dag",
-                    run_id="test_run_id",
-                ),
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            TriggerDagRun(
+                dag_id="test_dag",
+                run_id="test_run_id",
+                logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
             ),
-            mock.call.send(
-                msg=GetDagRunState(
-                    dag_id="test_dag",
-                    run_id="test_run_id",
-                ),
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            SetXCom(
+                key="trigger_run_id",
+                value="test_run_id",
+                dag_id="test_handle_trigger_dag_run_wait_for_completion",
+                task_id="test_task",
+                run_id="test_run",
+                map_index=-1,
             ),
+        )
+        # Verify GetDagRunState was called at least once (may be called multiple times during polling)
+        get_dag_run_state_calls = [
+            call_args
+            for call_args in mock_supervisor_comms.send.call_args_list
+            if len(call_args.args) > 0
+            and isinstance(call_args.args[0], GetDagRunState)
+            and call_args.args[0].dag_id == "test_dag"
+            and call_args.args[0].run_id == "test_run_id"
         ]
-        mock_supervisor_comms.assert_has_calls(expected_calls)
+        assert len(get_dag_run_state_calls) >= 1, (
+            f"Expected at least 1 GetDagRunState call, got {len(get_dag_run_state_calls)}"
+        )
 
     @pytest.mark.parametrize(
         ("allowed_states", "failed_states", "intermediate_state"),
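The lookup order that the patch gives `TriggerDagRunLink.get_link()` can be sketched without Airflow as follows. Only the key name `trigger_dag_id` comes from the patch; `resolve_trigger_dag_id`, its `xcom_get` callable, and the `rendered_fields` dict are hypothetical simplifications. In the patch itself the rendered-fields fallback applies only on Airflow 2.x.

```python
from typing import Callable, Mapping, Optional

XCOM_DAG_ID = "trigger_dag_id"  # key name taken from the patch


def resolve_trigger_dag_id(
    xcom_get: Callable[[str], Optional[str]],
    static_dag_id: str,
    rendered_fields: Optional[Mapping[str, str]] = None,
) -> str:
    """Prefer the dag_id stored at execution time, then fall back."""
    # 1. A value pushed during execution reflects the resolved dag_id.
    resolved = xcom_get(XCOM_DAG_ID)
    if resolved:
        return resolved
    # 2. Otherwise use rendered template fields when they are available.
    if rendered_fields:
        return rendered_fields.get("trigger_dag_id", static_dag_id)
    # 3. Last resort: the operator's static attribute.
    return static_dag_id


# A plain dict stands in for the XCom backend in this sketch.
pushed = {XCOM_DAG_ID: "dynamic_dag_id"}
assert resolve_trigger_dag_id(pushed.get, "static_dag_id") == "dynamic_dag_id"
assert resolve_trigger_dag_id({}.get, "static_dag_id") == "static_dag_id"
```

Because the fallback chain ends at the static attribute, links for tasks that ran before the XCom key existed would still resolve as they did previously, which appears to be what the backward-compatibility bullet refers to.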
47803ea to c46395b
Fixes #46402 - Dynamic dag_id resolution in TriggerDagRunOperator links
Description
This PR addresses the issue where TriggerDagRunOperator links fail to work correctly when the dag_id is dynamically determined at runtime (e.g., from XCom values or complex templates). The current implementation only uses the static trigger_dag_id attribute, which doesn't reflect the actual resolved dag_id when it's determined dynamically.
Changes
- Add XCOM_DAG_ID constant to store the resolved dag_id in XCom during task execution
- Update TriggerDagRunLink.get_link() to prioritize XCom values over static operator attributes
- Update both the Airflow 2.x (_trigger_dag_af_2) and 3.x (_trigger_dag_af_3) execution paths to push the resolved dag_id to XCom
Testing
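The task-runner test in this PR checks each expected message individually instead of asserting one ordered list of calls. A minimal sketch of why that can matter, using only `unittest.mock`: the message strings and the `comms` mock below are made-up placeholders, not the real supervisor messages.

```python
from unittest import mock

# Placeholder mock standing in for a comms channel; messages are plain strings here.
comms = mock.Mock()
comms.send("set_rendered_fields")
comms.send("set_xcom:trigger_dag_id")
comms.send("get_dag_run_state")
comms.send("trigger_dag_run")

# Order-insensitive checks: each expected message was sent at some point.
comms.send.assert_any_call("set_xcom:trigger_dag_id")
comms.send.assert_any_call("trigger_dag_run")


def ordered_check_passes() -> bool:
    """assert_has_calls needs the expected calls to appear consecutively."""
    expected = [mock.call("set_xcom:trigger_dag_id"), mock.call("trigger_dag_run")]
    try:
        comms.send.assert_has_calls(expected)
    except AssertionError:
        return False
    return True


# Count a message that may legitimately be sent more than once (e.g. polling).
poll_calls = [c for c in comms.send.call_args_list if c.args[0] == "get_dag_run_state"]
assert len(poll_calls) >= 1
assert ordered_check_passes() is False  # an interleaved call breaks the sequence
```

The trade-off is that individual checks no longer verify ordering, so a test written this way would not catch the XCom push happening after the trigger message.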
Type of Change
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in airflow-core/newsfragments.